import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np

class BaseDefense(nn.Module):
    """
    Base class for all defenses.

    ``forward`` applies ``_defense`` to its input. Subclasses override
    ``_defense``; the base implementation is the identity. When
    ``set_normalization_used`` has been called, ``forward`` denormalizes the
    input before ``_defense`` and re-normalizes the result afterwards, so
    ``_defense`` always operates on un-normalized tensors.
    """

    def __init__(self, device=None, iterations=1):
        """
        Args:
            device: Device for the module and its tracked tensor attributes.
                Defaults to 'cuda' when available, otherwise 'cpu'.
            iterations: Iteration count; converted with ``int()`` and must be
                non-negative.

        Raises:
            ValueError: If ``iterations`` cannot be converted to an int or is
                negative.
        """
        super(BaseDefense, self).__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        try:
            self.iterations = int(iterations)
        except (TypeError, ValueError, OverflowError) as err:
            # Narrow catch: a bare ``except`` would also swallow
            # KeyboardInterrupt / SystemExit.
            raise ValueError("Iterations must be an integer.") from err
        if self.iterations < 0:
            raise ValueError("Iterations cant be a negative number.")
        self.normalized = False
        # Names of plain tensor attributes that ``to`` moves along with the module.
        self.device_attributes = []

    def _defense(self, x):
        """
        Apply the defense to the input tensor.

        The base implementation returns ``x`` unchanged; subclasses override it.
        """
        return x

    def forward(self, x):
        """
        Forward pass of the defense.

        If normalization is configured, ``x`` is denormalized before the
        defense and the result is normalized again before being returned.
        """
        if self.normalized:
            x = self._denormalize(x)
        x = self._defense(x)
        if self.normalized:
            x = self._normalize(x)
        return x

    def set_normalization_used(self, mean, std):
        """
        Set the normalization parameters for the defense.

        Args:
            mean: Per-channel mean as a tensor or a sequence (tuple, list, array).
            std: Per-channel standard deviation, in the same forms as ``mean``.
        """
        # as_tensor returns tensors unchanged and converts tuples/lists/arrays.
        self.mean = torch.as_tensor(mean).to(self.device)
        self.std = torch.as_tensor(std).to(self.device)
        for name in ('mean', 'std'):
            # Avoid duplicate entries when this is called more than once.
            if name not in self.device_attributes:
                self.device_attributes.append(name)
        self.normalized = True

    def _denormalize(self, x):
        """
        Denormalize the input tensor: ``x * std + mean`` per channel.

        ``mean``/``std`` are reshaped to (1, C, 1, 1), so ``x`` is expected to
        be a 4-D (N, C, H, W) tensor.
        """
        x = x * self.std.view(1, -1, 1, 1) + self.mean.view(1, -1, 1, 1)
        return x

    def _normalize(self, x):
        """
        Normalize the input tensor per channel with ``self.mean``/``self.std``
        using torchvision's ``transforms.Normalize``.
        """
        x = transforms.Normalize(self.mean, self.std)(x)
        return x

    def to(self, device):
        """
        Move the defense, and every attribute listed in ``device_attributes``,
        to the specified device.

        Returns:
            self, matching ``nn.Module.to`` so calls can be chained.
        """
        super(BaseDefense, self).to(device)
        self.device = device
        for attr in self.device_attributes:
            self._change_device(attr)
        return self

    def _change_device(self, attr):
        """
        Move attribute ``attr`` to ``self.device`` if it exists on the instance.
        """
        if hasattr(self, attr):
            attr_value = getattr(self, attr)
            setattr(self, attr, attr_value.to(self.device))

    def _save_images(self, x, path, j=0):
        """
        Apply the defense and save each resulting image as a PNG.

        Files are named ``image_{j * batch_size + i}.png`` inside ``path``, so
        ``j`` acts as a batch index when saving consecutive batches.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).
            path (str): Directory to save the images in.
            j (int): Batch index used to offset the file numbering.
        """
        batch_size = x.shape[0]

        x = self.forward(x.to(self.device))
        if self.normalized:
            # forward() returns normalized values; PNGs need pixel space.
            x = self._denormalize(x)
        x = x.detach().cpu().numpy()

        for i, img in enumerate(x):
            img = img.transpose(1, 2, 0)
            if img.shape[2] == 1:
                # PIL cannot build an image from an (H, W, 1) array.
                img = img[:, :, 0]
            # Clip before the uint8 cast so out-of-range values saturate
            # instead of wrapping around.
            img = (np.clip(img, 0.0, 1.0) * 255).astype('uint8')
            Image.fromarray(img).save(f"{path}/image_{j*batch_size + i}.png")

    def get_distance(self, x, original, norms=(np.inf,)):
        """
        Get the per-sample distance between the defended input and ``original``.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).
            original (torch.Tensor): Reference tensor broadcastable to ``x``.
            norms (iterable): Orders passed to ``torch.linalg.vector_norm``;
                one result is produced per entry.

        Returns:
            list[torch.Tensor]: For each norm, a CPU tensor of shape (N,).

        Raises:
            ValueError: If ``torch.linalg.vector_norm`` rejects a norm order.
        """
        x_hat = self.forward(x)
        # Computed outside the try so shape errors are not reported as bad norms.
        diff = x_hat - original

        distances = []
        for norm in norms:
            try:
                dist = torch.linalg.vector_norm(diff, ord=norm, dim=[1, 2, 3])
            except (TypeError, ValueError, RuntimeError) as err:
                raise ValueError("Norm must be a valid integer or string.") from err
            distances.append(dist.detach().cpu())
        return distances

    def compute_fid(self, x):
        """
        Compute the FID score of the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            float: FID score of the input tensor. Not implemented yet:
            currently always returns None.
        """
        # TODO: implement FID computation.
        return None
        

        
        

    